{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OoYAU77gxcy2"
      },
      "outputs": [],
      "source": [
        "# Fetch the checkpoint file that is read with torch.load in the next cell.\n",
        "# -nc (no-clobber) skips the download when the file is already present, so re-running\n",
        "# the notebook neither re-downloads it nor leaves a duplicate `.1` copy behind.\n",
        "!wget -nc https://www.di.ens.fr/~lelarge/fast_train_a100.pt"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "\n",
        "# Prefer the GPU, but fall back to the CPU when CUDA is not available so the checkpoint\n",
        "# can still be loaded (set torch.device('mps') by hand on Apple silicon).\n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "\n",
        "checkpoint_path = 'fast_train_a100.pt'\n",
        "# map_location remaps the stored tensors onto `device`; weights_only=True restricts\n",
        "# unpickling to tensors and plain containers instead of arbitrary Python objects.\n",
        "checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)"
      ],
      "metadata": {
        "id": "KQLVzBlFxgZq"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Create the directory that will hold the downloaded model code.\n",
        "# -p makes this a no-op when the directory already exists, so the cell can be re-run.\n",
        "!mkdir -p nanogpt"
      ],
      "metadata": {
        "id": "6TIclMcTx_NK"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Step into the nanogpt directory so the files fetched by the following cells land inside it.\n",
        "%cd nanogpt"
      ],
      "metadata": {
        "id": "3d2DAr0hyg0V"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Download the model definition (imported later as nanogpt.model_inf).\n",
        "# -nc skips the download when model_inf.py is already present instead of saving a\n",
        "# duplicate `model_inf.py.1` that the import would never pick up.\n",
        "!wget -nc https://raw.githubusercontent.com/dataflowr/notebooks/refs/heads/master/llm/nanogpt/model_inf.py"
      ],
      "metadata": {
        "id": "pp9ZEiukyiiX"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Download meta.pkl next to model_inf.py (presumably the vocabulary metadata used to\n",
        "# encode/decode text - confirm). -nc skips the download when the file already exists,\n",
        "# so re-running the cell does not create a duplicate `meta.pkl.1`.\n",
        "!wget -nc https://github.com/dataflowr/notebooks/raw/refs/heads/master/llm/nanogpt/meta.pkl"
      ],
      "metadata": {
        "id": "FRk8KvfqynbU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Go back up one level so that `nanogpt` can be imported as a package from the working directory.\n",
        "%cd .."
      ],
      "metadata": {
        "id": "FCvimR9OyyuS"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from nanogpt.model_inf import GPTConfig, GPT\n",
        "import pickle"
      ],
      "metadata": {
        "id": "kpvnnQJoy5Je"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Rebuild the model configuration from the keyword arguments stored in the checkpoint.\n",
        "model_args = checkpoint[\"model_args\"]\n",
        "gptconf = GPTConfig(**model_args)"
      ],
      "metadata": {
        "id": "bdxWj9y5y6lm"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Instantiate the network described by the config; the checkpoint weights are\n",
        "# copied into it by the load_state_dict cell further down.\n",
        "model = GPT(gptconf)"
      ],
      "metadata": {
        "id": "mnIiBl3LzPGd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Rename every checkpoint key that starts with \"_orig_mod.\" by dropping that prefix\n",
        "# (presumably left behind by torch.compile - confirm) before the weights are loaded.\n",
        "state_dict = checkpoint[\"model\"]\n",
        "unwanted_prefix = \"_orig_mod.\"\n",
        "# Snapshot the affected keys first: the dict is modified while we walk over them.\n",
        "prefixed_keys = [key for key in state_dict if key.startswith(unwanted_prefix)]\n",
        "for key in prefixed_keys:\n",
        "    # Rename in place: remove the old entry and re-insert it under the stripped name.\n",
        "    state_dict[key[len(unwanted_prefix) :]] = state_dict.pop(key)"
      ],
      "metadata": {
        "id": "dWGjX_cuzRig"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Only the checkpoint tensors were mapped to `device`; the model was constructed without\n",
        "# a device argument and load_state_dict copies values into the existing parameters, so\n",
        "# move the model explicitly. eval() switches layers such as dropout to inference mode.\n",
        "model.to(device)\n",
        "model.eval()\n",
        "# Kept as the last expression so the cell still displays the key-matching report.\n",
        "model.load_state_dict(state_dict)"
      ],
      "metadata": {
        "id": "l_J0ffCbzWh1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "T0n2DGOjzY-C"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}